# coding=utf-8
# Copyright 2024-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a tool to generate `src/huggingface_hub/inference/_generated/types`."""

import argparse
import re
import sys
from pathlib import Path
from typing import Dict, List, Literal, NoReturn, Optional

import libcst as cst
from helpers import check_and_update_file_content, format_source_code


# Repository root: this script lives one directory below it (it is referred to as `utils/generate_inference_types.py`).
_REPO_ROOT = Path(__file__).parents[1]

# Package sources and the generated types sub-package.
huggingface_hub_folder_path = _REPO_ROOT / "src" / "huggingface_hub"
INFERENCE_TYPES_FOLDER_PATH = huggingface_hub_folder_path / "inference" / "_generated" / "types"
MAIN_INIT_PY_FILE = huggingface_hub_folder_path / "__init__.py"

# Generated reference pages (English and Korean) listing the inference types.
REFERENCE_PACKAGE_EN_PATH = _REPO_ROOT / "docs" / "source" / "en" / "package_reference" / "inference_types.md"
REFERENCE_PACKAGE_KO_PATH = _REPO_ROOT / "docs" / "source" / "ko" / "package_reference" / "inference_types.md"

# Files of the types folder that are not task modules; they are skipped when post-processing generated modules.
IGNORE_FILES = ["__init__.py", "base.py"]

# Matches a plain `@dataclass` decorator directly followed by a base-less `class Name:` line; group 1 is the class
# name. With re.VERBOSE the literal whitespace/newlines of the pattern are ignored, so line breaks are matched only by
# the explicit `\n` escapes.
BASE_DATACLASS_REGEX = re.compile(
    r"""
    ^@dataclass
    \nclass\s(\w+):\n
""",
    re.VERBOSE | re.MULTILINE,
)

# Matches a class already rewritten to `@dataclass_with_extra` + `class Name(BaseInferenceType):`; group 1 is the
# class name. Used to list the dataclasses exported by each task module.
INHERITED_DATACLASS_REGEX = re.compile(
    r"""
    ^@dataclass_with_extra
    \nclass\s(\w+)\(BaseInferenceType\):
""",
    re.VERBOSE | re.MULTILINE,
)

# Matches a top-level assignment `Name = <value>` (the line must not start with whitespace); group 1 is the name and
# group 2 the right-hand side. Used to collect type aliases.
# NOTE(review): any top-level assignment matches, so this presumably relies on generated modules having no other
# top-level assignments — confirm if the generator output changes.
TYPE_ALIAS_REGEX = re.compile(
    r"""
    ^(?!\s) # to make sure the line does not start with whitespace (top-level)
    (\w+)
    \s*=\s*
    (.+)
    $
    """,
    re.VERBOSE | re.MULTILINE,
)
# Matches a `: Optional[...]` annotation ending a line; group 1 is the wrapped type.
# NOTE(review): not referenced anywhere in the visible code (optional defaults are added by a substring check
# instead) — possibly dead; confirm before removing.
OPTIONAL_FIELD_REGEX = re.compile(r": Optional\[(.+)\]$", re.MULTILINE)


# Header of the generated types `__init__.py`. `create_init_py` starts from this text and appends one
# `from .<module> import ...` line per task module after it.
INIT_PY_HEADER = """
# This file is auto-generated by `utils/generate_inference_types.py`.
# Do not modify it manually.
#
# ruff: noqa: F401

from .base import BaseInferenceType
"""

# Locates the `"inference._generated.types": [...]` entry of ./src/huggingface_hub/__init__.py so that
# `add_dataclasses_to_main_init` can replace the whole entry with the up-to-date list of exported classes.
# Group 1 is the current list body: re.DOTALL lets it span several lines, and the lazy `.*?` stops at the first `]`.
MAIN_INIT_PY_REGEX = re.compile(
    r"""
\"inference\._generated\.types\":\s*\[ # module name
    (.*?) # all dataclasses listed
\] # closing bracket
""",
    re.MULTILINE | re.VERBOSE | re.DOTALL,
)


# Class names that appear in several task modules. `_fix_naming_for_shared_classes` renames each of them with a
# task-specific prefix (e.g. "ClassificationOutput" -> "AudioClassificationOutputElement" in `audio_classification`)
# so that every class exported from the types package gets a unique name.
SHARED_CLASSES = [
    "BoundingBox",
    "ClassificationOutputTransform",
    "ClassificationOutput",
    "GenerationParameters",
    "TargetSize",
    "EarlyStoppingEnum",
]

# Markdown template of the English `inference_types.md` reference page. `generate_reference_package` fills the
# `{types}` placeholder with one `## <task>` section per task.
REFERENCE_PACKAGE_EN_CONTENT = """
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<!--⚠️ Note that this file is auto-generated by `utils/generate_inference_types.py`. Do not modify it manually.-->


# Inference types

This page lists the types (e.g. dataclasses) available for each task supported on the Hugging Face Hub.
Each task is specified using a JSON schema, and the types are generated from these schemas - with some customization
due to Python requirements.
Visit [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)
to find the JSON schemas for each task.

This part of the lib is still under development and will be improved in future releases.


{types}
"""

# Korean translation of the template above, filled the same way by `generate_reference_package` (its section headings
# additionally carry an explicit `[[...]]` anchor).
REFERENCE_PACKAGE_KO_CONTENT = """
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<!--⚠️ Note that this file is auto-generated by `utils/generate_inference_types.py`. Do not modify it manually.-->


# 추론 타입[[inference-types]]

이 페이지에는 Hugging Face Hub에서 지원하는 타입(예: 데이터 클래스)이 나열되어 있습니다.
각 작업은 JSON 스키마를 사용하여 지정되며, 이러한 스키마에 의해서 타입이 생성됩니다. 이때 Python 요구 사항으로 인해 일부 사용자 정의가 있을 수 있습니다.

각 작업의 JSON 스키마를 확인하려면 [@huggingface.js/tasks](https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks)를 확인하세요.

라이브러리에서 이 부분은 아직 개발 중이며, 향후 릴리즈에서 개선될 예정입니다.


{types}
"""


def _replace_class_name(content: str, cls: str, new_cls: str) -> str:
    """
    Replace the class name `cls` with the new class name `new_cls` in the content.
    """
    pattern = rf"""
        (?<![\w'"])
        (['"]?)
        {cls}
        (['"]?)
        (?![\w'"])
    """

    def replacement(m):
        quote_start = m.group(1) or ""
        quote_end = m.group(2) or ""
        return f"{quote_start}{new_cls}{quote_end}"

    content = re.sub(pattern, replacement, content, flags=re.VERBOSE)
    return content


def _inherit_from_base(content: str) -> str:
    """Make every plain `@dataclass` of a generated module inherit from `BaseInferenceType`.

    Inserts `from .base import BaseInferenceType, dataclass_with_extra` right before the `from dataclasses import`
    line, then rewrites each `@dataclass` + `class Name:` pair into `@dataclass_with_extra` +
    `class Name(BaseInferenceType):`.

    The import line is only inserted when it is not already present. `check_inference_types` re-reads modules that a
    previous run may already have processed, and the function must not stack duplicate import lines on them.
    """
    base_import = "from .base import BaseInferenceType, dataclass_with_extra"
    if base_import not in content:
        content = content.replace("\nfrom dataclasses import", f"\n{base_import}\nfrom dataclasses import")
    return BASE_DATACLASS_REGEX.sub(r"@dataclass_with_extra\nclass \1(BaseInferenceType):\n", content)


def _delete_empty_lines(content: str) -> str:
    return "\n".join([line for line in content.split("\n") if line.strip()])


def _fix_naming_for_shared_classes(content: str, module_name: str) -> str:
    """Give a task-specific name to each class of `SHARED_CLASSES` that appears in `content`.

    The new name is the CamelCase module name followed by the shared class name (e.g. in a module named
    `object_detection`, "BoundingBox" would become "ObjectDetectionBoundingBox"), with two tweaks:
    - when the CamelCase module name contains "Classification", a leading "Classification" is dropped from the shared
      class name (to avoid "...ClassificationClassificationOutput");
    - a name ending in "ClassificationOutput" gets an "Element" suffix (e.g. "AudioClassificationOutputElement").
    """
    # Very hacky naming: "audio_classification" -> "AudioClassification". Same for every shared class, so built once.
    task_prefix = "".join(map(str.capitalize, module_name.split("_")))
    is_classification_task = "Classification" in task_prefix

    for shared_cls in SHARED_CLASSES:
        # Nothing to rename if the shared class is not used in this module.
        if shared_cls not in content:
            continue
        suffix = shared_cls.removeprefix("Classification") if is_classification_task else shared_cls
        renamed_cls = task_prefix + suffix
        if renamed_cls.endswith("ClassificationOutput"):
            renamed_cls += "Element"
        content = _replace_class_name(content, shared_cls, renamed_cls)

    return content


def _fix_text2text_shared_parameters(content: str, module_name: str) -> str:
    if module_name in ("summarization", "translation"):
        content = content.replace(
            "Text2TextGenerationParameters",
            f"{module_name.capitalize()}GenerationParameters",
        )
        content = content.replace(
            "Text2TextGenerationTruncationStrategy",
            f"{module_name.capitalize()}GenerationTruncationStrategy",
        )
    return content


def _make_optional_fields_default_to_none(content: str):
    lines = []
    for line in content.split("\n"):
        if "Optional[" in line and not line.endswith("None"):
            line += " = None"

        lines.append(line)

    return "\n".join(lines)


def _list_dataclasses(content: str) -> List[str]:
    """Return, in order of appearance, the names of the `@dataclass_with_extra` classes of the module.

    Only classes declared as `class Name(BaseInferenceType):` are listed.
    """
    return [match.group(1) for match in INHERITED_DATACLASS_REGEX.finditer(content)]


def _list_type_aliases(content: str) -> List[str]:
    """Return, in order of appearance, the names assigned by top-level `Name = ...` statements of the module."""
    return [match.group(1) for match in TYPE_ALIAS_REGEX.finditer(content)]


class DeprecatedRemover(cst.CSTTransformer):
    """libcst transformer removing deprecated classes and dataclass fields.

    - A class whose docstring (first statement of its body) contains `@deprecated` is removed entirely.
    - An annotated field (`name: Type ...`) immediately followed by a string statement containing `@deprecated` is
      removed together with that string.
    - A class left without any statement is removed as well.
    """

    def is_deprecated(self, docstring: Optional[str]) -> bool:
        """Return whether `docstring` is present and contains `@deprecated` (case-insensitive)."""
        if docstring is None:
            return False
        return "@deprecated" in docstring.lower()

    def get_docstring(self, body: List[cst.BaseStatement]) -> Optional[str]:
        """Return the value of the simple string literal opening `body`, or None when it does not start with one."""
        if not body:
            return None
        head = body[0]
        if not isinstance(head, cst.SimpleStatementLine):
            return None
        expression = head.body[0]
        is_string_statement = isinstance(expression, cst.Expr) and isinstance(expression.value, cst.SimpleString)
        return expression.value.evaluated_value if is_string_statement else None

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> Optional[cst.ClassDef]:
        """Drop the class if it is deprecated; otherwise drop its deprecated fields and their docstrings."""
        if self.is_deprecated(self.get_docstring(original_node.body.body)):
            return cst.RemoveFromParent()

        statements = list(updated_node.body.body)
        kept_statements = []
        skip_next = False
        for index, statement in enumerate(statements):
            if skip_next:
                # Docstring of the deprecated field dropped on the previous iteration.
                skip_next = False
                continue
            is_field = isinstance(statement, cst.SimpleStatementLine) and isinstance(statement.body[0], cst.AnnAssign)
            has_next = index + 1 < len(statements)
            if is_field and has_next and self.is_deprecated(self.get_docstring([statements[index + 1]])):
                # Drop the field now and its docstring on the next iteration.
                skip_next = True
                continue
            kept_statements.append(statement)

        if not kept_statements:
            return cst.RemoveFromParent()

        return updated_node.with_changes(body=updated_node.body.with_changes(body=kept_statements))


def _clean_deprecated_fields(content: str) -> str:
    """Parse `content` with libcst, strip deprecated classes/fields with `DeprecatedRemover` and return the new code."""
    return cst.parse_module(content).visit(DeprecatedRemover()).code


def fix_inference_classes(content: str, module_name: str) -> str:
    """Apply the post-processing steps to the source of the generated task module `module_name`.

    Steps, in order: inherit from `BaseInferenceType`, drop empty lines, rename shared classes, rename the
    text2text parameter classes, and default `Optional[...]` fields to `None`.
    """
    steps = (
        _inherit_from_base,
        _delete_empty_lines,
        lambda text: _fix_naming_for_shared_classes(text, module_name),
        lambda text: _fix_text2text_shared_parameters(text, module_name),
        _make_optional_fields_default_to_none,
    )
    for step in steps:
        content = step(content)
    return content


def create_init_py(dataclasses: Dict[str, List[str]]):
    """Create the content of the types `__init__.py`, re-exporting the given names of every task module.

    Args:
        dataclasses: mapping of module name (file stem) to the names to import from that module.

    Returns:
        `INIT_PY_HEADER`, a newline, then one `from .<module> import a, b, ...` line per module, in module-name order.
        Modules without any name are skipped: `from .<module> import` with an empty name list is invalid Python.
    """
    # Sorted so the output does not depend on the order in which the caller collected the modules.
    import_lines = [
        f"from .{module} import {', '.join(dataclasses[module])}" for module in sorted(dataclasses) if dataclasses[module]
    ]
    return INIT_PY_HEADER + "\n" + "\n".join(import_lines)


def add_dataclasses_to_main_init(content: str, dataclasses: Dict[str, List[str]]):
    """Rewrite the `"inference._generated.types": [...]` entry of the main `__init__.py` source `content`.

    Args:
        content: source of ./src/huggingface_hub/__init__.py.
        dataclasses: mapping of module name to exported names; all names are merged, de-duplicated and sorted.

    Returns:
        `content` with the entry replaced by `"inference._generated.types": ['A', 'B', ...]`.

    Raises:
        ValueError: if `content` contains no such entry. Silently returning `content` unchanged would let the
            exported names go stale while the check still passes.
    """
    dataclasses_list = sorted({cls for classes in dataclasses.values() for cls in classes})
    dataclasses_str = ", ".join(f"'{cls}'" for cls in dataclasses_list)
    new_entry = f'"inference._generated.types": [{dataclasses_str}]'

    # A callable replacement is inserted verbatim: no backslash-escape or group-reference processing.
    updated_content, n_replacements = MAIN_INIT_PY_REGEX.subn(lambda _match: new_entry, content)
    if n_replacements == 0:
        raise ValueError(
            'Could not find the `"inference._generated.types": [...]` entry in the main `__init__.py` content.'
        )
    return updated_content


def generate_reference_package(dataclasses: Dict[str, List[str]], language: Literal["en", "ko"]) -> str:
    """Generate the reference package content.

    Args:
        dataclasses: mapping of task (module) name to the dataclass names to document.
        language: "en" or "ko"; selects the page template and the format of the section headings.

    Returns:
        The page, with one `## <task>` section per task (sorted by task), each listing one `[[autodoc]]` entry per
        class (sorted). Tasks without any class are skipped.

    Raises:
        ValueError: if `language` is not supported. Checked upfront, so it is raised even when `dataclasses` is empty.
    """
    if language == "en":
        template = REFERENCE_PACKAGE_EN_CONTENT
    elif language == "ko":
        template = REFERENCE_PACKAGE_KO_CONTENT
    else:
        raise ValueError(f"Language {language} is not supported.")

    per_task_docs = []
    for task in sorted(dataclasses.keys()):
        classes = sorted(dataclasses[task])
        if not classes:
            # Nothing to document (and the Korean heading anchor below needs at least one class).
            continue
        lines_str = "\n\n".join(f"[[autodoc]] huggingface_hub.{cls}" for cls in classes)
        if language == "en":
            # e.g. '## audio_classification'
            heading = f"## {task}"
        else:
            # e.g. '## audio_classification[[huggingface_hub.AudioClassificationInput]]'
            heading = f"## {task}[[huggingface_hub.{classes[0]}]]"
        per_task_docs.append(f"\n{heading}\n\n{lines_str}\n\n")

    return template.format(types="\n".join(per_task_docs))


def check_inference_types(update: bool) -> NoReturn:
    """Check and update inference types.

    Post-processes every generated task module of `INFERENCE_TYPES_FOLDER_PATH` (deprecated classes/fields removed,
    shared classes renamed, ...). From the collected class names it then regenerates:
    - the types `__init__.py`;
    - the `"inference._generated.types"` entry of the main `__init__.py`;
    - the EN/KO `inference_types.md` reference pages.

    Every generated content is passed to `check_and_update_file_content` together with `update`.

    This script is used in the `make style` and `make quality` checks.

    Args:
        update: forwarded to `check_and_update_file_content`. Per the CLI help, it controls whether files are
            re-generated when a change is detected.

    Exits the process with status 0 once every file has been handled.
    """
    dataclasses = {}
    aliases = {}
    # `Path.glob` yields entries in filesystem order, which varies across platforms. Sort them so that module order
    # (and thus the generated files) is deterministic.
    for file in sorted(INFERENCE_TYPES_FOLDER_PATH.glob("*.py")):
        if file.name in IGNORE_FILES:
            continue
        content = file.read_text()
        content = _clean_deprecated_fields(content)
        fixed_content = fix_inference_classes(content, module_name=file.stem)
        formatted_content = format_source_code(fixed_content)
        dataclasses[file.stem] = _list_dataclasses(formatted_content)
        aliases[file.stem] = _list_type_aliases(formatted_content)
        check_and_update_file_content(file, formatted_content, update)

    # Type aliases are exported alongside dataclasses, but only dataclasses get an `[[autodoc]]` entry below.
    all_classes = {module: dataclasses[module] + aliases[module] for module in dataclasses}
    init_py_content = create_init_py(all_classes)
    init_py_content = format_source_code(init_py_content)
    init_py_file = INFERENCE_TYPES_FOLDER_PATH / "__init__.py"
    check_and_update_file_content(init_py_file, init_py_content, update)

    main_init_py_content = MAIN_INIT_PY_FILE.read_text()
    updated_main_init_py_content = add_dataclasses_to_main_init(main_init_py_content, all_classes)
    updated_main_init_py_content = format_source_code(updated_main_init_py_content)
    check_and_update_file_content(MAIN_INIT_PY_FILE, updated_main_init_py_content, update)
    reference_package_content_en = generate_reference_package(dataclasses, "en")
    check_and_update_file_content(REFERENCE_PACKAGE_EN_PATH, reference_package_content_en, update)

    reference_package_content_ko = generate_reference_package(dataclasses, "ko")
    check_and_update_file_content(REFERENCE_PACKAGE_KO_PATH, reference_package_content_ko, update)

    print("✅ All good! (inference types)")
    # `sys.exit` rather than the `exit` builtin, which is only injected by the `site` module.
    sys.exit(0)


if __name__ == "__main__":
    # Command-line entry point. `--update` is forwarded to `check_inference_types`, which passes it on to
    # `check_and_update_file_content`; see the help text for its meaning.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--update",
        action="store_true",
        help=(
            "Whether to re-generate files in `./src/huggingface_hub/inference/_generated/types/` if a change is detected."
        ),
    )
    check_inference_types(update=cli_parser.parse_args().update)
